{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**this kernel contains my stage-2 submissions**\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Hello kagglers**\n",
    "this is my first kaggle competition,as a beginner what i tried is as follows :\n",
    "1.thanks @rishabhiitbhu for your public kernel,i used code from that kernel,which is this one : https://www.kaggle.com/rishabhiitbhu/unet-with-resnet34-encoder-pytorch\n",
    "2.changed his architecture from unet with resnet34 to unet with se_resnext50_32x4d\n",
    "\n",
    "in stage-1 i also tried a lot more different architectures in this kernel like linknet with resnet101,unet with resnet101,linknet with vgg11 etc etc,but none of them performed well than this unet with vgg11  in stage-1 public leaderboard,of course there were few architectures i tried like unet with senet with radams and i saw the graph is better than it was before,but as those encoders are large,i ran out of memory in kaggle kernel,if i had a gpu,i am optimistic those models would get us medal hahaha,anyway what are the models from segmentation model pytorch github ripository you  tried for this competition? and what the outcome of each?if you can remember please share with me in the comment box,it will help me a lot for my upcoming competitions\n",
    "\n",
    "\n",
    "**most importantly**\n",
    "1. @rishabhiitbhu is using equal number of pneumothorax and non pneumothorax samples for training,i have decided to increase the non-pneumothorax sample a bit to see if it does well in terms of overall prediction\n",
    "2. i have used radams here\n",
    "3. after 20th epoch which this model stopped training because : \"RuntimeError: DataLoader worker (pid(s) 1094) exited unexpectedly\" so i  reduced the size of num_works and  trained the network again\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "****This Kernel uses UNet architecture with se_resnext50_32x4d encoder, I've used [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) library which has many inbuilt segmentation architectures. This kernel is inspired by [Yury](https://www.kaggle.com/deyury)'s discussion thread [here](https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/99440#591985). I've used snippets from multiple other public kernels I've given due credits at the end of this notebook.\n",
    "\n",
    "What's down below?\n",
    "\n",
    "* UNet with imagenet pretrained se_resnext50_32x4d architecture\n",
    "* Training on 512x512 sized images/masks with Standard Augmentations\n",
    "* MixedLoss (weighted sum of Focal loss and dice loss)\n",
    "* Gradient Accumulution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**ChangeLog**\n",
    "1. version 1 contains my exact model for this competition's stage 2 submission\n",
    "2. in version 3 i will try speckle noise or multiplicative noise (implemented in albumentation few days ago after i requested for the implementation),check here : https://github.com/albu/albumentations/issues/439\n",
    "3. in version 4 i will try fpn instead of unet\n",
    "4. version 7 - FPN with inceptionresnetv2 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U git+https://github.com/albu/albumentations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import pdb\n",
    "import time\n",
    "import warnings\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "import torch.optim as optim\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.utils.data import DataLoader, Dataset, sampler\n",
    "from matplotlib import pyplot as plt\n",
    "from albumentations import (HorizontalFlip,OpticalDistortion,VerticalFlip,GridDistortion,RandomBrightnessContrast,OneOf,ElasticTransform,RandomGamma,IAAEmboss,Blur,RandomRotate90,Transpose, ShiftScaleRotate, Normalize, Resize, Compose, GaussNoise,MultiplicativeNoise)\n",
    "from albumentations.pytorch import ToTensor\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import albumentations as A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A.MultiplicativeNoise()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/qubvel/segmentation_models.pytorch > /dev/null 2>&1 # Install segmentations_models.pytorch, with no bash output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import segmentation_models_pytorch as smp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_length_decode(rle, height=1024, width=1024, fill_value=1):\n",
    "    component = np.zeros((height, width), np.float32)\n",
    "    component = component.reshape(-1)\n",
    "    rle = np.array([int(s) for s in rle.strip().split(' ')])\n",
    "    rle = rle.reshape(-1, 2)\n",
    "    start = 0\n",
    "    for index, length in rle:\n",
    "        start = start+index\n",
    "        end = start+length\n",
    "        component[start: end] = fill_value\n",
    "        start = end\n",
    "    component = component.reshape(width, height).T\n",
    "    return component\n",
    "\n",
    "def run_length_encode(component):\n",
    "    component = component.T.flatten()\n",
    "    start = np.where(component[1:] > component[:-1])[0]+1\n",
    "    end = np.where(component[:-1] > component[1:])[0]+1\n",
    "    length = end-start\n",
    "    rle = []\n",
    "    for i in range(len(length)):\n",
    "        if i == 0:\n",
    "            rle.extend([start[0], length[0]])\n",
    "        else:\n",
    "            rle.extend([start[i]-end[i-1], length[i]])\n",
    "    rle = ' '.join([str(r) for r in rle])\n",
    "    return rle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SIIMDataset(Dataset):\n",
    "    def __init__(self, df, fnames, data_folder, size, mean, std, phase):\n",
    "        self.df = df\n",
    "        self.root = data_folder\n",
    "        self.size = size\n",
    "        self.mean = mean\n",
    "        self.std = std\n",
    "        self.phase = phase\n",
    "        self.transforms = get_transforms(phase, size, mean, std)\n",
    "        self.gb = self.df.groupby('ImageId')\n",
    "        self.fnames = fnames\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image_id = self.fnames[idx]\n",
    "        df = self.gb.get_group(image_id)\n",
    "        annotations = df[' EncodedPixels'].tolist()\n",
    "        image_path = os.path.join(self.root, image_id + \".png\")\n",
    "        image = cv2.imread(image_path)\n",
    "        mask = np.zeros([1024, 1024])\n",
    "        if annotations[0] != '-1':\n",
    "            for rle in annotations:\n",
    "                mask += run_length_decode(rle)\n",
    "        mask = (mask >= 1).astype('float32') # for overlap cases\n",
    "        augmented = self.transforms(image=image, mask=mask)\n",
    "        image = augmented['image']\n",
    "        mask = augmented['mask']\n",
    "        return image, mask\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.fnames)\n",
    "\n",
    "\n",
    "def get_transforms(phase, size, mean, std):\n",
    "    list_transforms = []\n",
    "    if phase == \"train\":\n",
    "        list_transforms.extend(\n",
    "            [\n",
    "                 HorizontalFlip(p=0.5),\n",
    "                ShiftScaleRotate(\n",
    "                    shift_limit=0,  # no resizing\n",
    "                    scale_limit=0.1,\n",
    "                    rotate_limit=10, # rotate\n",
    "                    p=0.5,\n",
    "                    border_mode=cv2.BORDER_CONSTANT\n",
    "                ),\n",
    "                 GaussNoise(),\n",
    "                A.MultiplicativeNoise(multiplier=1.5, p=1),\n",
    "            ]\n",
    "        )\n",
    "    list_transforms.extend(\n",
    "        [\n",
    "            Resize(size, size),\n",
    "            Normalize(mean=mean, std=std, p=1),\n",
    "            ToTensor(),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    list_trfms = Compose(list_transforms)\n",
    "    return list_trfms\n",
    "\n",
    "def provider(\n",
    "    fold,\n",
    "    total_folds,\n",
    "    data_folder,\n",
    "    df_path,\n",
    "    phase,\n",
    "    size,\n",
    "    mean=None,\n",
    "    std=None,\n",
    "    batch_size=8,\n",
    "    num_workers=2,\n",
    "):\n",
    "    df_all = pd.read_csv(df_path)\n",
    "    df = df_all.drop_duplicates('ImageId')\n",
    "    df_with_mask = df[df[\" EncodedPixels\"] != \"-1\"]\n",
    "    df_with_mask['has_mask'] = 1\n",
    "    df_without_mask = df[df[\" EncodedPixels\"] == \"-1\"]\n",
    "    df_without_mask['has_mask'] = 0\n",
    "    df_without_mask_sampled = df_without_mask.sample(len(df_with_mask)+1500, random_state=2019) # random state is imp\n",
    "    df = pd.concat([df_with_mask, df_without_mask_sampled])\n",
    "    \n",
    "    #NOTE: equal number of positive and negative cases are chosen.\n",
    "    \n",
    "    kfold = StratifiedKFold(total_folds, shuffle=True, random_state=43)\n",
    "    train_idx, val_idx = list(kfold.split(df[\"ImageId\"], df[\"has_mask\"]))[fold]\n",
    "    train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]\n",
    "    df = train_df if phase == \"train\" else val_df\n",
    "    # NOTE: total_folds=5 -> train/val : 80%/20%\n",
    "    \n",
    "    fnames = df['ImageId'].values\n",
    "    \n",
    "    image_dataset = SIIMDataset(df_all, fnames, data_folder, size, mean, std, phase)\n",
    "\n",
    "    dataloader = DataLoader(\n",
    "        image_dataset,\n",
    "        batch_size=batch_size,\n",
    "        num_workers=num_workers,\n",
    "        pin_memory=True,\n",
    "        shuffle=True,\n",
    "    )\n",
    "    return dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "print(os.listdir('../input/'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission_path = '../input/siimmy/stage_2_sample_submission.csv'\n",
    "train_rle_path = '../input/mysiim/train-rle.csv'\n",
    "data_folder = \"../input/siimpng/siimpng/train_png\"\n",
    "test_data_folder = \"../input/siim_stage2_png\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "ab = glob.glob('../input/siimpng/siimpng/train_png/*.png')\n",
    "len(ab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a= pd.read_csv(train_rle_path)\n",
    "len(a)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataloader sanity check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = provider(\n",
    "    fold=0,\n",
    "    total_folds=5,\n",
    "    data_folder=data_folder,\n",
    "    df_path=train_rle_path,\n",
    "    phase=\"train\",\n",
    "    size=512,\n",
    "    mean = (0.485, 0.456, 0.406),\n",
    "    std = (0.229, 0.224, 0.225),\n",
    "    batch_size=16,\n",
    "    num_workers=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(dataloader)) # get a batch from the dataloader\n",
    "images, masks = batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot some random images in the `batch`\n",
    "idx = random.choice(range(16))\n",
    "plt.imshow(images[idx][0], cmap='bone')\n",
    "plt.imshow(masks[idx][0], alpha=0.2, cmap='Reds')\n",
    "plt.show()\n",
    "if len(np.unique(masks[idx][0])) == 1: # only zeros\n",
    "    print('Chosen image has no ground truth mask, rerun the cell')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Losses\n",
    "\n",
    "This kernel uses a weighted sum of Focal Loss and Dice Loss, let's call it MixedLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dice_loss(input, target):\n",
    "    input = torch.sigmoid(input)\n",
    "    smooth = 1.0\n",
    "    iflat = input.view(-1)\n",
    "    tflat = target.view(-1)\n",
    "    intersection = (iflat * tflat).sum()\n",
    "    return ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))\n",
    "\n",
    "\n",
    "class FocalLoss(nn.Module):\n",
    "    def __init__(self, gamma):\n",
    "        super().__init__()\n",
    "        self.gamma = gamma\n",
    "\n",
    "    def forward(self, input, target):\n",
    "        if not (target.size() == input.size()):\n",
    "            raise ValueError(\"Target size ({}) must be the same as input size ({})\"\n",
    "                             .format(target.size(), input.size()))\n",
    "        max_val = (-input).clamp(min=0)\n",
    "        loss = input - input * target + max_val + \\\n",
    "            ((-max_val).exp() + (-input - max_val).exp()).log()\n",
    "        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))\n",
    "        loss = (invprobs * self.gamma).exp() * loss\n",
    "        return loss.mean()\n",
    "\n",
    "\n",
    "class MixedLoss(nn.Module):\n",
    "    def __init__(self, alpha, gamma):\n",
    "        super().__init__()\n",
    "        self.alpha = alpha\n",
    "        self.focal = FocalLoss(gamma)\n",
    "\n",
    "    def forward(self, input, target):\n",
    "        loss = self.alpha*self.focal(input, target) - torch.log(dice_loss(input, target))\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Some more utility functions\n",
    "\n",
    "Here are some utility functions for calculating IoU and Dice scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(X, threshold):\n",
    "    X_p = np.copy(X)\n",
    "    preds = (X_p > threshold).astype('uint8')\n",
    "    return preds\n",
    "\n",
    "def metric(probability, truth, threshold=0.5, reduction='none'):\n",
    "    '''Calculates dice of positive and negative images seperately'''\n",
    "    '''probability and truth must be torch tensors'''\n",
    "    batch_size = len(truth)\n",
    "    with torch.no_grad():\n",
    "        probability = probability.view(batch_size, -1)\n",
    "        truth = truth.view(batch_size, -1)\n",
    "        assert(probability.shape == truth.shape)\n",
    "\n",
    "        p = (probability > threshold).float()\n",
    "        t = (truth > 0.5).float()\n",
    "\n",
    "        t_sum = t.sum(-1)\n",
    "        p_sum = p.sum(-1)\n",
    "        neg_index = torch.nonzero(t_sum == 0)\n",
    "        pos_index = torch.nonzero(t_sum >= 1)\n",
    "\n",
    "        dice_neg = (p_sum == 0).float()\n",
    "        dice_pos = 2 * (p*t).sum(-1)/((p+t).sum(-1))\n",
    "\n",
    "        dice_neg = dice_neg[neg_index]\n",
    "        dice_pos = dice_pos[pos_index]\n",
    "        dice = torch.cat([dice_pos, dice_neg])\n",
    "\n",
    "        dice_neg = np.nan_to_num(dice_neg.mean().item(), 0)\n",
    "        dice_pos = np.nan_to_num(dice_pos.mean().item(), 0)\n",
    "        dice = dice.mean().item()\n",
    "\n",
    "        num_neg = len(neg_index)\n",
    "        num_pos = len(pos_index)\n",
    "\n",
    "    return dice, dice_neg, dice_pos, num_neg, num_pos\n",
    "\n",
    "class Meter:\n",
    "    '''A meter to keep track of iou and dice scores throughout an epoch'''\n",
    "    def __init__(self, phase, epoch):\n",
    "        self.base_threshold = 0.5 # <<<<<<<<<<< here's the threshold\n",
    "        self.base_dice_scores = []\n",
    "        self.dice_neg_scores = []\n",
    "        self.dice_pos_scores = []\n",
    "        self.iou_scores = []\n",
    "\n",
    "    def update(self, targets, outputs):\n",
    "        probs = torch.sigmoid(outputs)\n",
    "        dice, dice_neg, dice_pos, _, _ = metric(probs, targets, self.base_threshold)\n",
    "        self.base_dice_scores.append(dice)\n",
    "        self.dice_pos_scores.append(dice_pos)\n",
    "        self.dice_neg_scores.append(dice_neg)\n",
    "        preds = predict(probs, self.base_threshold)\n",
    "        iou = compute_iou_batch(preds, targets, classes=[1])\n",
    "        self.iou_scores.append(iou)\n",
    "\n",
    "    def get_metrics(self):\n",
    "        dice = np.mean(self.base_dice_scores)\n",
    "        dice_neg = np.mean(self.dice_neg_scores)\n",
    "        dice_pos = np.mean(self.dice_pos_scores)\n",
    "        dices = [dice, dice_neg, dice_pos]\n",
    "        iou = np.nanmean(self.iou_scores)\n",
    "        return dices, iou\n",
    "\n",
    "def epoch_log(phase, epoch, epoch_loss, meter, start):\n",
    "    '''logging the metrics at the end of an epoch'''\n",
    "    dices, iou = meter.get_metrics()\n",
    "    dice, dice_neg, dice_pos = dices\n",
    "    print(\"Loss: %0.4f | dice: %0.4f | dice_neg: %0.4f | dice_pos: %0.4f | IoU: %0.4f\" % (epoch_loss, dice, dice_neg, dice_pos, iou))\n",
    "    return dice, iou\n",
    "\n",
    "def compute_ious(pred, label, classes, ignore_index=255, only_present=True):\n",
    "    '''computes iou for one ground truth mask and predicted mask'''\n",
    "    pred[label == ignore_index] = 0\n",
    "    ious = []\n",
    "    for c in classes:\n",
    "        label_c = label == c\n",
    "        if only_present and np.sum(label_c) == 0:\n",
    "            ious.append(np.nan)\n",
    "            continue\n",
    "        pred_c = pred == c\n",
    "        intersection = np.logical_and(pred_c, label_c).sum()\n",
    "        union = np.logical_or(pred_c, label_c).sum()\n",
    "        if union != 0:\n",
    "            ious.append(intersection / union)\n",
    "    return ious if ious else [1]\n",
    "\n",
    "\n",
    "def compute_iou_batch(outputs, labels, classes=None):\n",
    "    '''computes mean iou for a batch of ground truth masks and predicted masks'''\n",
    "    ious = []\n",
    "    preds = np.copy(outputs) # copy is imp\n",
    "    labels = np.array(labels) # tensor to np\n",
    "    for pred, label in zip(preds, labels):\n",
    "        ious.append(np.nanmean(compute_ious(pred, label, classes)))\n",
    "    iou = np.nanmean(ious)\n",
    "    return iou\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](http://)## FPN  with inceptionresnetv2 model\n",
    "Let's take a look at the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = smp.FPN('inceptionresnetv2', encoder_weights=\"imagenet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model # a *deeper* look"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Radams**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch.optim.optimizer import Optimizer, required\n",
    "\n",
    "class RAdam(Optimizer):\n",
    "\n",
    "    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):\n",
    "        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)\n",
    "        self.buffer = [[None, None, None] for ind in range(10)]\n",
    "        super(RAdam, self).__init__(params, defaults)\n",
    "\n",
    "    def __setstate__(self, state):\n",
    "        super(RAdam, self).__setstate__(state)\n",
    "\n",
    "    def step(self, closure=None):\n",
    "\n",
    "        loss = None\n",
    "        if closure is not None:\n",
    "            loss = closure()\n",
    "\n",
    "        for group in self.param_groups:\n",
    "\n",
    "            for p in group['params']:\n",
    "                if p.grad is None:\n",
    "                    continue\n",
    "                grad = p.grad.data.float()\n",
    "                if grad.is_sparse:\n",
    "                    raise RuntimeError('RAdam does not support sparse gradients')\n",
    "\n",
    "                p_data_fp32 = p.data.float()\n",
    "\n",
    "                state = self.state[p]\n",
    "\n",
    "                if len(state) == 0:\n",
    "                    state['step'] = 0\n",
    "                    state['exp_avg'] = torch.zeros_like(p_data_fp32)\n",
    "                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)\n",
    "                else:\n",
    "                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)\n",
    "                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)\n",
    "\n",
    "                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n",
    "                beta1, beta2 = group['betas']\n",
    "\n",
    "                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n",
    "                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n",
    "\n",
    "                state['step'] += 1\n",
    "                buffered = self.buffer[int(state['step'] % 10)]\n",
    "                if state['step'] == buffered[0]:\n",
    "                    N_sma, step_size = buffered[1], buffered[2]\n",
    "                else:\n",
    "                    buffered[0] = state['step']\n",
    "                    beta2_t = beta2 ** state['step']\n",
    "                    N_sma_max = 2 / (1 - beta2) - 1\n",
    "                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)\n",
    "                    buffered[1] = N_sma\n",
    "\n",
    "                    # more conservative since it's an approximated value\n",
    "                    if N_sma >= 5:\n",
    "                        step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])\n",
    "                    else:\n",
    "                        step_size = group['lr'] / (1 - beta1 ** state['step'])\n",
    "                    buffered[2] = step_size\n",
    "\n",
    "                if group['weight_decay'] != 0:\n",
    "                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)\n",
    "\n",
    "                # more conservative since it's an approximated value\n",
    "                if N_sma >= 5:            \n",
    "                    denom = exp_avg_sq.sqrt().add_(group['eps'])\n",
    "                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)\n",
    "                else:\n",
    "                    p_data_fp32.add_(-step_size, exp_avg)\n",
    "\n",
    "                p.data.copy_(p_data_fp32)\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Training and validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Trainer(object):\n",
    "    '''This class takes care of training and validation of our model'''\n",
    "    def __init__(self, model):\n",
    "        self.fold = 1\n",
    "        self.total_folds = 5\n",
    "        self.num_workers = 4\n",
    "        self.batch_size = {\"train\": 4, \"val\": 4}\n",
    "        self.accumulation_steps = 32 // self.batch_size['train']\n",
    "        self.lr = 5e-4\n",
    "        self.num_epochs = 32\n",
    "        self.best_loss = float(\"inf\")\n",
    "        self.phases = [\"train\", \"val\"]\n",
    "        self.device = torch.device(\"cuda:0\")\n",
    "        torch.set_default_tensor_type(\"torch.cuda.FloatTensor\")\n",
    "        self.net = model\n",
    "        self.criterion = MixedLoss(10.0, 2.0)\n",
    "        #self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)\n",
    "        self.optimizer = RAdam(model.parameters(), lr=self.lr)\n",
    "\n",
    "        self.scheduler = ReduceLROnPlateau(self.optimizer, mode=\"min\", patience=3, verbose=True)\n",
    "        self.net = self.net.to(self.device)\n",
    "        cudnn.benchmark = True\n",
    "        self.dataloaders = {\n",
    "            phase: provider(\n",
    "                fold=1,\n",
    "                total_folds=5,\n",
    "                data_folder=data_folder,\n",
    "                df_path=train_rle_path,\n",
    "                phase=phase,\n",
    "                size=512,\n",
    "                mean=(0.485, 0.456, 0.406),\n",
    "                std=(0.229, 0.224, 0.225),\n",
    "                batch_size=self.batch_size[phase],\n",
    "                num_workers=self.num_workers,\n",
    "            )\n",
    "            for phase in self.phases\n",
    "        }\n",
    "        self.losses = {phase: [] for phase in self.phases}\n",
    "        self.iou_scores = {phase: [] for phase in self.phases}\n",
    "        self.dice_scores = {phase: [] for phase in self.phases}\n",
    "        \n",
    "    def forward(self, images, targets):\n",
    "        images = images.to(self.device)\n",
    "        masks = targets.to(self.device)\n",
    "        outputs = self.net(images)\n",
    "        loss = self.criterion(outputs, masks)\n",
    "        return loss, outputs\n",
    "\n",
    "    def iterate(self, epoch, phase):\n",
    "        meter = Meter(phase, epoch)\n",
    "        start = time.strftime(\"%H:%M:%S\")\n",
    "        print(f\"Starting epoch: {epoch} | phase: {phase} | ⏰: {start}\")\n",
    "        batch_size = self.batch_size[phase]\n",
    "        self.net.train(phase == \"train\")\n",
    "        dataloader = self.dataloaders[phase]\n",
    "        running_loss = 0.0\n",
    "        total_batches = len(dataloader)\n",
    "#         tk0 = tqdm(dataloader, total=total_batches)\n",
    "        self.optimizer.zero_grad()\n",
    "        for itr, batch in enumerate(dataloader):\n",
    "            images, targets = batch\n",
    "            loss, outputs = self.forward(images, targets)\n",
    "            loss = loss / self.accumulation_steps\n",
    "            if phase == \"train\":\n",
    "                loss.backward()\n",
    "                if (itr + 1 ) % self.accumulation_steps == 0:\n",
    "                    self.optimizer.step()\n",
    "                    self.optimizer.zero_grad()\n",
    "            running_loss += loss.item()\n",
    "            outputs = outputs.detach().cpu()\n",
    "            meter.update(targets, outputs)\n",
    "#             tk0.set_postfix(loss=(running_loss / ((itr + 1))))\n",
    "        epoch_loss = (running_loss * self.accumulation_steps) / total_batches\n",
    "        dice, iou = epoch_log(phase, epoch, epoch_loss, meter, start)\n",
    "        self.losses[phase].append(epoch_loss)\n",
    "        self.dice_scores[phase].append(dice)\n",
    "        self.iou_scores[phase].append(iou)\n",
    "        torch.cuda.empty_cache()\n",
    "        return epoch_loss\n",
    "\n",
    "    def start(self):\n",
    "        for epoch in range(self.num_epochs):\n",
    "            self.iterate(epoch, \"train\")\n",
    "            state = {\n",
    "                \"epoch\": epoch,\n",
    "                \"best_loss\": self.best_loss,\n",
    "                \"state_dict\": self.net.state_dict(),\n",
    "                \"optimizer\": self.optimizer.state_dict(),\n",
    "            }\n",
    "            val_loss = self.iterate(epoch, \"val\")\n",
    "            self.scheduler.step(val_loss)\n",
    "            if val_loss < self.best_loss:\n",
    "                print(\"******** New optimal found, saving state ********\")\n",
    "                state[\"best_loss\"] = self.best_loss = val_loss\n",
    "                torch.save(state, \"./model.pth\")\n",
    "            print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_trainer = Trainer(model)\n",
    "model_trainer.start()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PLOT TRAINING\n",
    "losses = model_trainer.losses\n",
    "dice_scores = model_trainer.dice_scores # overall dice\n",
    "iou_scores = model_trainer.iou_scores\n",
    "\n",
    "def plot(scores, name):\n",
    "    plt.figure(figsize=(15,5))\n",
    "    plt.plot(range(len(scores[\"train\"])), scores[\"train\"], label=f'train {name}')\n",
    "    plt.plot(range(len(scores[\"train\"])), scores[\"val\"], label=f'val {name}')\n",
    "    plt.title(f'{name} plot'); plt.xlabel('Epoch'); plt.ylabel(f'{name}');\n",
    "    plt.legend(); \n",
    "    plt.show()\n",
    "\n",
    "plot(losses, \"BCE loss\")\n",
    "plot(dice_scores, \"Dice score\")\n",
    "plot(iou_scores, \"IoU score\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestDataset(Dataset):\n",
    "    def __init__(self, root, df, size, mean, std, tta=4):\n",
    "        self.root = root\n",
    "        self.size = size\n",
    "        self.fnames = list(df[\"ImageId\"])\n",
    "        self.num_samples = len(self.fnames)\n",
    "        self.transform = Compose(\n",
    "            [\n",
    "                Normalize(mean=mean, std=std, p=1),\n",
    "                Resize(size, size),\n",
    "                ToTensor(),\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        fname = self.fnames[idx]\n",
    "        path = os.path.join(self.root, fname + \".png\")\n",
    "        image = cv2.imread(path)\n",
    "        images = self.transform(image=image)[\"image\"]\n",
    "        return images\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_samples\n",
    "\n",
    "def post_process(probability, threshold, min_size):\n",
    "    mask = cv2.threshold(probability, threshold, 1, cv2.THRESH_BINARY)[1]\n",
    "    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))\n",
    "    predictions = np.zeros((1024, 1024), np.float32)\n",
    "    num = 0\n",
    "    for c in range(1, num_component):\n",
    "        p = (component == c)\n",
    "        if p.sum() > min_size:\n",
    "            predictions[p] = 1\n",
    "            num += 1\n",
    "    return predictions, num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size = 512\n",
    "mean = (0.485, 0.456, 0.406)\n",
    "std = (0.229, 0.224, 0.225)\n",
    "num_workers = 8\n",
    "batch_size = 16\n",
    "best_threshold = 0.5\n",
    "min_size = 3500\n",
    "device = torch.device(\"cuda:0\")\n",
    "df = pd.read_csv(sample_submission_path)\n",
    "testset = DataLoader(\n",
    "    TestDataset(test_data_folder, df, size, mean, std),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    num_workers=num_workers,\n",
    "    pin_memory=True,\n",
    ")\n",
    "model = model_trainer.net # get the model from model_trainer object\n",
    "model.eval()\n",
    "state = torch.load('./model.pth', map_location=lambda storage, loc: storage)\n",
    "model.load_state_dict(state[\"state_dict\"])\n",
    "encoded_pixels = []\n",
    "for i, batch in enumerate(tqdm(testset)):\n",
    "    preds = torch.sigmoid(model(batch.to(device)))\n",
    "    preds = preds.detach().cpu().numpy()[:, 0, :, :] # (batch_size, 1, size, size) -> (batch_size, size, size)\n",
    "    for probability in preds:\n",
    "        if probability.shape != (1024, 1024):\n",
    "            probability = cv2.resize(probability, dsize=(1024, 1024), interpolation=cv2.INTER_LINEAR)\n",
    "        predict, num_predict = post_process(probability, best_threshold, min_size)\n",
    "        if num_predict == 0:\n",
    "            encoded_pixels.append('-1')\n",
    "        else:\n",
    "            r = run_length_encode(predict)\n",
    "            encoded_pixels.append(r)\n",
    "df['EncodedPixels'] = encoded_pixels\n",
    "df.to_csv('submission.csv', columns=['ImageId', 'EncodedPixels'], index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "`segmentation_models_pytorch` has got many other segmentation models implemented, try them out :)\n",
    "\n",
    "I've learnt a lot from fellow kagglers, I've borrowed a lot of code from you guys, special shout-out to [@Abhishek](https://www.kaggle.com/abhishek), [@Yury](https://www.kaggle.com/deyury), [Heng](https://www.kaggle.com/hengck23), [Ekhtiar](https://www.kaggle.com/ekhtiar), [lafoss](https://www.kaggle.com/iafoss), [Siddhartha](https://www.kaggle.com/meaninglesslives) and many other kagglers :)\n",
    "\n",
    "Kaggle is <3"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
